#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File: deps.py
# @Author: lotus163
# @Date: 2022/11/21
import datetime
from typing import Optional, Dict

from fastapi import Depends, Request
from jose import jwt
from loguru import logger

from backend.config.setting import settings
from backend.core.casbin import CasbinService
from backend.core.custom_exc import AccessTokenFail, LoginException
from backend.core.redis import MyRedis, get_login_tokens_key
from backend.core.request import RequestService
from backend.core.security import get_token
from backend.schemas import LoginUserSch, OperLogSch


async def get_redis(request: Request) -> MyRedis:
    """Return the Redis client held on the application state.

    The object found at ``request.app.state.redis`` is awaited and the awaited
    result is returned.

    NOTE(review): ``LoginDepends`` reads ``request.app.state.redis`` directly
    without awaiting it; confirm that ``MyRedis`` is awaitable (returns itself
    when awaited), otherwise this dependency will raise ``TypeError``.
    """
    app_state = request.app.state
    redis_client = await app_state.redis
    return redis_client


# https://www.cnblogs.com/CharmCode/p/14191112.html?ivk_sa=1024320u
def check_token(token: Optional[str] = Depends(get_token)) -> Dict:
    """Decode and verify the JWT access token of the current request.

    Args:
        token: Raw token string supplied by ``get_token``; may be None.

    Returns:
        The decoded token claims (payload).

    Raises:
        AccessTokenFail: if the token is missing, or if ``jwt.decode`` rejects
            it for any reason (bad signature, expiry, malformed input, ...).
    """
    if not token:
        # Reject up front: previously `"token:" + None` raised a TypeError that
        # was caught below and misreported as an expired token.
        raise AccessTokenFail('token已过期! -- token is missing')
    # Never log the raw bearer token itself -- anyone with log access could
    # replay it. Log only non-sensitive metadata.
    logger.debug("checking access token (length={})", len(token))
    try:
        return jwt.decode(token=token, key=settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except Exception as e:  # jwt.JWTError, jwt.ExpiredSignatureError, ...
        # Deliberately broad: every decode failure must surface as an auth
        # error. Chain the cause so the original traceback is preserved.
        raise AccessTokenFail(f'token已过期! -- {e}') from e


class LoginDepends:
    """Dependencies shared by login-related endpoints.

    Wraps the incoming request in a ``RequestService`` helper and exposes the
    Redis client stored on ``request.app.state``.
    """

    def __init__(self, request: Request):
        service = RequestService(request)
        # Helper around the raw request (route name, method, IP, params, ...).
        self.request_service: RequestService = service
        # Application-wide Redis client, used as-is (not awaited).
        self.redis: MyRedis = request.app.state.redis


class UserDepends(LoginDepends):
    """Dependencies for endpoints that require an authenticated user.

    Adds the decoded JWT payload (from ``check_token``) on top of
    ``LoginDepends``. Once ``get_login_user`` has been awaited, it also holds
    the logged-in user loaded from Redis.
    """

    # HTTP method -> operation-log business type (0 = other, 1 = insert,
    # 2 = update, 3 = delete, per the codes used by the original code).
    # Per HTTP semantics (RFC 9110) POST creates a resource and PUT replaces /
    # updates one. The original mapping had these two swapped, which logged
    # inserts as updates and vice versa.
    _BUSINESS_TYPES: Dict[str, int] = {
        "POST": 1,    # insert
        "PUT": 2,     # update
        "DELETE": 3,  # delete
    }

    def __init__(self, request: Request, payload=Depends(check_token)):
        super().__init__(request)
        # Claims decoded from the bearer token by check_token.
        self.payload: Dict = payload
        # Filled in by get_login_user(); None until then.
        self.login_user: Optional[LoginUserSch] = None
        # Casbin enforcement is not wired up yet (see check_permission):
        # self.casbin_service: CasbinService = CasbinService(request.app.state.enforcer)

    async def get_login_user(self) -> None:
        """Load the logged-in user from Redis into ``self.login_user``.

        The token only carries a session key (``login_user_key``). The user
        record itself is stored in Redis as JSON under the key built by
        ``get_login_tokens_key``.

        Raises:
            LoginException: if the token has no ``login_user_key`` or nothing
                is stored in Redis under it.
        """
        session_key = self.payload.get("login_user_key")
        if not session_key:
            # No session key: don't query Redis with a key built from None.
            raise LoginException("用户不存在，请重新登录")
        login_user_json = await self.redis.get(get_login_tokens_key(session_key))
        if login_user_json is None:
            raise LoginException("用户不存在，请重新登录")
        self.login_user = LoginUserSch.parse_raw(login_user_json)

    def get_oper_log(self) -> None:
        """Build the operation-log record for this request and store it in the request scope.

        Must be called after ``get_login_user``, because it reads
        ``self.login_user``. The record is stored under
        ``scope["oper_log_json"]`` so that it can be read on the response side
        to persist the operation log. An already-present entry is left
        untouched (``setdefault``).
        """
        service = self.request_service
        route_name = service.get_route_name()
        http_method = service.get_method()
        oper_log_sch = OperLogSch(
            title=route_name,
            businessType=self.__get_business_type(http_method),
            method=route_name,
            requestMethod=http_method,
            # NOTE(review): magic operator category code; confirm meaning against OperLogSch.
            operatorType=1,
            operName=self.login_user.user.userName,
            deptName=self.get_dept_name(),
            operUrl=service.get_url_path(),
            operIp=service.get_ip_addr(),
            operLocation=service.get_ip_location(),
            operParam=service.get_oper_param(),
            operTime=datetime.datetime.now()
        )
        service.request.scope.setdefault("oper_log_json", oper_log_sch)

    def __get_business_type(self, request_method: str) -> int:
        """Map an HTTP method to its business-type code; unmapped methods give 0."""
        return self._BUSINESS_TYPES.get(request_method, 0)

    def get_dept_name(self) -> Optional[str]:
        """Return the logged-in user's department name, or None when the user has no department."""
        dept = self.login_user.user.dept
        return dept.deptName if dept else None

    def check_permission(self) -> None:
        """Permission-check hook; currently a no-op.

        NOTE(review): presumably meant for Casbin enforcement (CasbinService is
        imported, and its construction is commented out in __init__). This is
        not implemented yet.
        """


async def get_user_depends(user_depends: UserDepends = Depends(UserDepends)) -> UserDepends:
    """Authenticate the caller, run the permission check and capture an operation log.

    Loads the logged-in user from Redis, calls the permission hook, then
    stores the operation-log record in the request scope.

    Returns:
        The populated ``UserDepends`` instance.
    """
    await user_depends.get_login_user()
    # Run the same permission hook as get_user_depends_nolog, so routes that
    # log operations are not exempt from permission checks.
    user_depends.check_permission()
    user_depends.get_oper_log()
    return user_depends


async def get_user_depends_nolog(user_depends: UserDepends = Depends(UserDepends)) -> UserDepends:
    """Authenticate the caller and run the permission check, without recording an operation log.

    Returns:
        The ``UserDepends`` instance with ``login_user`` loaded.
    """
    deps = user_depends
    await deps.get_login_user()
    deps.check_permission()
    return deps
